import sys
sys.path.insert(0, '.')
sys.path.insert(0, '..')

import torch
import torch.nn.functional as F
from torch import optim, nn
import numpy as np
import os, argparse, copy, json
import pickle as pkl
from scipy.spatial.transform import Rotation as R
from psbody.mesh import Mesh
from manopth.manolayer import ManoLayer

from utils import *
import utils
import utils.model_util as model_util
from utils.anchor_utils import masking_load_driver, anchor_load_driver, recover_anchor_batch


def _build_mano_layer(mano_path, nn_hand_params, use_pca, use_left):
  """Create a CUDA MANO layer (axis-angle rotations, flat_hand_mean=False) for one hand side."""
  return ManoLayer(
      flat_hand_mean=False,
      side='left' if use_left else 'right',
      mano_root=mano_path,
      ncomps=nn_hand_params,
      use_pca=use_pca,
      root_rot_mode='axisang',
      joint_rot_mode='axisang'
  ).cuda()


def _mano_forward(mano_layer, rot_var, theta_var, beta_var, transl_var, nn_frames):
  """Run MANO on every frame with one shared shape code; return (verts, joints) scaled by 0.001."""
  hand_verts, hand_joints = mano_layer(
      torch.cat([rot_var, theta_var], dim=-1),
      # The single (1, 10) shape code is broadcast to every frame.
      beta_var.unsqueeze(1).repeat(1, nn_frames, 1).view(-1, 10),
      transl_var,
  )
  hand_verts = hand_verts.view(nn_frames, 778, 3) * 0.001
  hand_joints = hand_joints.view(nn_frames, -1, 3) * 0.001
  return hand_verts, hand_joints


def _temporal_mse(x):
  """Mean squared difference between consecutive entries along dim 0 (frames).

  Returns a zero scalar when there are fewer than two frames: F.mse_loss on
  empty slices yields NaN, which would otherwise poison the total loss.
  """
  if x.size(0) < 2:
    return x.new_zeros(())
  return F.mse_loss(x[1:], x[:-1])


def get_optimized_hand_fr_joints_v4_anchors(joints, base_pts, tot_base_pts_trans, tot_base_normals_trans, nn_hand_params=24, use_left=None):
  """Fit a MANO hand sequence to target 3D joints with a two-stage Adam optimisation.

  Stage 1: 200 iterations at lr=0.1, optimising only the global rotation and
  translation against the joint error.
  Stage 2: 3000 iterations at lr=0.01, optimising rotation, translation,
  shape (beta) and pose (theta). The loss combines the joint error, pose/shape
  priors, and temporal smoothness of the pose and the joints.

  Args:
    joints: numpy array of target joints. It is compared with MANO joints
      shaped (nn_frames, J, 3) after scaling by 0.001, so it presumably uses
      the same units and joint layout -- TODO confirm against the producer.
    base_pts, tot_base_pts_trans, tot_base_normals_trans: accepted for
      interface compatibility; not used by the fit.
    nn_hand_params: ignored. A 24-component PCA pose space is always used;
      this preserves the historical behaviour.
    use_left: fit a left hand if True, a right hand if False. When None, the
      module-level ``use_left`` flag (set by the ``__main__`` driver) is used,
      defaulting to False if it is not defined.

  Returns:
    dict of numpy arrays: 'hand_verts' (nn_frames, 778, 3), 'hand_joints'
    (nn_frames, J, 3), 'rot_var' (nn_frames, 3), 'theta_var' (nn_frames, 24),
    'beta_var' (1, 10), 'transl_var' (nn_frames, 3). The mesh and joints are
    evaluated at the final parameters.

  Requires a CUDA device.
  """
  if use_left is None:
    # Backward compatibility: this function historically read the driver's global.
    use_left = bool(globals().get('use_left', False))

  joints = torch.from_numpy(joints).float().cuda()

  # NOTE: the caller-supplied nn_hand_params is deliberately overridden. A
  # 24-component PCA MANO is always fitted, matching previously saved results.
  nn_hand_params = 24
  use_pca = True
  mano_layer = _build_mano_layer("./data/mano_models/mano/models", nn_hand_params, use_pca, use_left)

  nn_frames = joints.size(0)

  # Random initialisation. The order of the randn calls is kept so that RNG
  # consumption matches earlier runs.
  beta_var = torch.randn([1, 10]).cuda()
  rot_var = torch.randn([nn_frames, 3]).cuda()  # global orientation (axis-angle)
  theta_var = torch.randn([nn_frames, nn_hand_params]).cuda()
  transl_var = torch.randn([nn_frames, 3]).cuda()
  for var in (beta_var, rot_var, theta_var, transl_var):
    var.requires_grad_()

  # ---- Stage 1: global rotation / translation only ----
  num_iters = 200
  opt = optim.Adam([rot_var, transl_var], lr=0.1)
  for i in range(num_iters):
    opt.zero_grad()
    hand_verts, hand_joints = _mano_forward(mano_layer, rot_var, theta_var, beta_var, transl_var, nn_frames)

    joints_pred_loss = torch.sum((hand_joints - joints) ** 2, dim=-1).mean()
    # The following terms are diagnostics only in this stage (not in the loss).
    pose_smoothness_loss = _temporal_mse(theta_var)
    shape_prior_loss = torch.mean(beta_var ** 2)
    pose_prior_loss = torch.mean(theta_var ** 2)

    loss = joints_pred_loss * 1000
    loss.backward()
    opt.step()

    print('Iter {}: {}'.format(i, loss.item()), flush=True)
    print('\tShape Prior Loss: {}'.format(shape_prior_loss.item()))
    print('\tPose Prior Loss: {}'.format(pose_prior_loss.item()))
    print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item()))
    print('\tJoints Prediction Loss: {}'.format(joints_pred_loss.item()))

  # ---- Stage 2: all parameters, with priors and temporal smoothness ----
  num_iters = 3000
  opt = optim.Adam([rot_var, transl_var, beta_var, theta_var], lr=0.01)
  # step_size == num_iters: the single decay happens only after the last step.
  scheduler = optim.lr_scheduler.StepLR(opt, step_size=num_iters, gamma=0.5)
  for i in range(num_iters):
    opt.zero_grad()
    hand_verts, hand_joints = _mano_forward(mano_layer, rot_var, theta_var, beta_var, transl_var, nn_frames)

    joints_pred_loss = torch.sum((hand_joints - joints) ** 2, dim=-1).mean()
    pose_smoothness_loss = _temporal_mse(theta_var)
    shape_prior_loss = torch.mean(beta_var ** 2)
    pose_prior_loss = torch.mean(theta_var ** 2)
    joints_smoothness_loss = _temporal_mse(hand_joints)
    loss = (joints_pred_loss * 30 + pose_smoothness_loss * 0.000001 + shape_prior_loss * 0.0001
            + pose_prior_loss * 0.00001 + joints_smoothness_loss * 200.)

    loss.backward()
    opt.step()
    scheduler.step()

    print('Iter {}: {}'.format(i, loss.item()), flush=True)
    print('\tShape Prior Loss: {}'.format(shape_prior_loss.item()))
    print('\tPose Prior Loss: {}'.format(pose_prior_loss.item()))
    print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item()))
    print('\tJoints Prediction Loss: {}'.format(joints_pred_loss.item()))
    print('\tJoint Smoothness Loss: {}'.format(joints_smoothness_loss.item()))

  # Re-evaluate MANO at the final parameters. The in-loop forward pass ran
  # before the last optimiser step, so its mesh would lag the returned params.
  with torch.no_grad():
    hand_verts, hand_joints = _mano_forward(mano_layer, rot_var, theta_var, beta_var, transl_var, nn_frames)

  optimized_dict = {
    'hand_verts': hand_verts.detach().cpu().numpy(),
    'hand_joints': hand_joints.detach().cpu().numpy(),
    'rot_var': rot_var.detach().cpu().numpy(),
    'theta_var': theta_var.detach().cpu().numpy(),
    'beta_var': beta_var.detach().cpu().numpy(),
    'transl_var': transl_var.detach().cpu().numpy(),
  }
  return optimized_dict



def get_rhand_joints_verts_fr_params(rhand_transl, rhand_rot, rhand_theta, rhand_beta):
  """Evaluate a right-hand MANO model (24 PCA components, flat hand mean) on given parameters.

  Args:
    rhand_transl: translation passed straight to MANO (presumably a
      (nn_frames, 3) CUDA tensor -- TODO confirm).
    rhand_rot: global axis-angle rotation. Its first dimension defines the
      number of frames.
    rhand_theta: PCA pose coefficients, concatenated after ``rhand_rot`` along
      the last dimension.
    rhand_beta: shape coefficients, flattened to rows of 10.

  Returns:
    Tuple ``(hand_verts, hand_joints)`` of numpy arrays shaped
    (nn_frames, 778, 3) and (nn_frames, J, 3), both scaled by 0.001.
  """
  layer = ManoLayer(
      mano_root="./data/mano_models/mano/models",
      side='right',
      use_pca=True,
      ncomps=24,
      flat_hand_mean=True,
      root_rot_mode='axisang',
      joint_rot_mode='axisang',
  ).cuda()
  frame_count = rhand_rot.size(0)

  pose_input = torch.cat([rhand_rot, rhand_theta], dim=-1)
  shape_input = rhand_beta.view(-1, 10)
  raw_verts, raw_joints = layer(pose_input, shape_input, rhand_transl)

  scaled_verts = raw_verts.view(frame_count, 778, 3) * 0.001
  scaled_joints = raw_joints.view(frame_count, -1, 3) * 0.001
  return (
      scaled_verts.detach().cpu().numpy(),
      scaled_joints.detach().cpu().numpy(),
  )




def get_arctic_seq_paths(processed_arctic_root="./data/arctic_processed_data/processed_seqs"):
    """Enumerate processed ARCTIC sequence files laid out as ``<root>/<subject>/<seq>.npy``.

    Non-directory entries directly under ``processed_arctic_root`` are skipped,
    as are files inside a subject folder that do not end in ``.npy``.

    Args:
        processed_arctic_root: folder holding one sub-folder per subject.

    Returns:
        ``(paths, tags)``: parallel lists of full ``.npy`` paths and tags of the
        form ``"<subject>_<name before first '.'>"``. The order follows
        ``os.listdir`` and is intentionally left unsorted, because callers embed
        the enumeration index in output filenames.
    """
    tot_arctic_seq_paths = []
    tot_arctic_seq_tags = []
    for cur_subj_folder in os.listdir(processed_arctic_root):
        full_cur_subj_folder = os.path.join(processed_arctic_root, cur_subj_folder)
        if not os.path.isdir(full_cur_subj_folder):
            # A stray file here (notes, .DS_Store, ...) would make os.listdir raise.
            continue
        for cur_subj_seq_nm in os.listdir(full_cur_subj_folder):
            if not cur_subj_seq_nm.endswith(".npy"):
                continue
            tot_arctic_seq_paths.append(os.path.join(full_cur_subj_folder, cur_subj_seq_nm))
            tot_arctic_seq_tags.append(f"{cur_subj_folder}_{cur_subj_seq_nm.split('.')[0]}")
    return tot_arctic_seq_paths, tot_arctic_seq_tags


if __name__=='__main__':

  nn_hand_params = 45

  test_tag = "tag1"
  test_tags = ["tag1", "tag2", "tag3", "tag4", "tag5", "tag6"]

  pred_infos_sv_folder = "./save/arctic_save_res"
  new_pred_infos_sv_folder = "./save/arctic_save_res"

  tot_rnd_seeds = range(0, 121, 1)

  # Keys whose arrays carry a leading batch axis; they are sliced and
  # concatenated along axis 1 instead of axis 0.
  batch_first_keys = {
    "tot_base_pts", "tot_base_normals", "tot_obj_rot", "tot_obj_transl",
    "tot_obj_pcs", "tot_rhand_joints", "tot_gt_rhand_joints",
  }
  # Every tag except the last contributes only this many leading entries.
  nn_frames_partial_tag = 30

  # Hand side derived from the tag. It must stay a module-level global:
  # get_optimized_hand_fr_joints_v4_anchors reads ``use_left`` from module scope.
  use_left = 'use_left' in test_tag

  tot_arctic_seq_paths, tot_arctic_seq_tags = get_arctic_seq_paths()

  for i_test_seq, (cur_seq_path, cur_seq_tag) in enumerate(zip(tot_arctic_seq_paths, tot_arctic_seq_tags)):
    for seed in tot_rnd_seeds:

      # Merge the per-tag prediction files for this (sequence, seed).
      tot_data = {}
      for i_tag, cur_test_tag in enumerate(test_tags):
        pred_joints_info_nm = f"predicted_infos_seq_{i_test_seq}_{cur_seq_tag}_seed_{seed}_tag_{cur_test_tag}.npy"
        cur_pred_joints_info_fn = os.path.join(pred_infos_sv_folder, pred_joints_info_nm)
        print(f"cur_pred_joints_info_fn: {cur_pred_joints_info_fn}")
        if not os.path.exists(cur_pred_joints_info_fn):
          continue
        # NOTE: allow_pickle=True unpickles the file; only load trusted outputs.
        cur_data = np.load(cur_pred_joints_info_fn, allow_pickle=True).item()
        print(f"Data loaded from {cur_pred_joints_info_fn}")
        is_last_tag = i_tag == len(test_tags) - 1
        for cur_k, cur_v in cur_data.items():
          if not is_last_tag:
            # NOTE(review): this also truncates non-temporal entries such as
            # 'template_obj_fs'; confirm that is intended.
            if cur_k in batch_first_keys:
              cur_v = cur_v[:, :nn_frames_partial_tag]
            else:
              cur_v = cur_v[:nn_frames_partial_tag]
          tot_data.setdefault(cur_k, []).append(cur_v)

      if not tot_data:
        # No prediction file exists for this (sequence, seed); indexing the
        # empty dict below would raise KeyError and abort the whole sweep.
        print(f"No predictions found for seq {i_test_seq} ({cur_seq_tag}), seed {seed}; skipping.")
        continue

      for cur_k in tot_data:
        concat_axis = 1 if cur_k in batch_first_keys else 0
        tot_data[cur_k] = np.concatenate(tot_data[cur_k], axis=concat_axis)
        print(f"cur_k: {cur_k}, {tot_data[cur_k].shape}")
      data = tot_data
      print(f"data: {data.keys()}")

      merged_pred_sv_infos_sv_fn_nm = f"predicted_infos_sv_dict_seq_{i_test_seq}_seed_{seed}_tag_{test_tags[0]}_{cur_seq_tag}_multi_ntag_{len(test_tags)}.npy"
      merged_pred_sv_infos_sv_fn = os.path.join(new_pred_infos_sv_folder, merged_pred_sv_infos_sv_fn_nm)
      np.save(merged_pred_sv_infos_sv_fn, data)

      targets = data['targets']
      outputs = data['outputs']
      tot_base_pts = data["tot_base_pts"][0]
      tot_base_normals = data['tot_base_normals'][0]
      obj_verts = data["tot_obj_pcs"][0]
      tot_obj_rot = data['tot_obj_rot'][0]  # ws x 3 x 3
      tot_obj_transl = data['tot_obj_transl'][0]
      print(f"tot_obj_rot: {tot_obj_rot.shape}, tot_obj_transl: {tot_obj_transl.shape}")

      # Per-frame translation reshaped to (frames, 1, dim) so it broadcasts over points.
      obj_transl_b = tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1])

      if len(tot_base_pts.shape) == 2:
        # Shared base points: broadcast against every frame's rotation.
        tot_base_pts_trans = np.matmul(tot_base_pts.reshape(1, tot_base_pts.shape[0], 3), tot_obj_rot) + obj_transl_b
        tot_base_pts = np.matmul(tot_base_pts, tot_obj_rot[0]) + obj_transl_b[0]
        tot_base_normals_trans = np.matmul(
          tot_base_normals.reshape(1, tot_base_normals.shape[0], 3), tot_obj_rot
        )
      else:
        print(f"tot_base_pts: {tot_base_pts.shape}, tot_obj_rot: {tot_obj_rot.shape}, tot_obj_transl: {tot_obj_transl.shape}")
        tot_base_pts_trans = np.matmul(tot_base_pts, tot_obj_rot) + obj_transl_b
        # Identical computation; the array is not mutated afterwards, so reuse it.
        tot_base_pts = tot_base_pts_trans
        tot_base_normals_trans = np.matmul(tot_base_normals, tot_obj_rot)

      outputs = np.matmul(outputs, tot_obj_rot) + obj_transl_b
      targets = np.matmul(targets, tot_obj_rot) + obj_transl_b
      print(f"tot_base_pts: {tot_base_pts.shape}")

      obj_verts_trans = np.matmul(obj_verts, tot_obj_rot) + obj_transl_b
      obj_faces = data['template_obj_fs']
      print(f"obj_verts_trans: {obj_verts_trans.shape}, obj_faces: {obj_faces.shape}")

      optimized_dict = get_optimized_hand_fr_joints_v4_anchors(outputs, tot_base_pts, tot_base_pts_trans, tot_base_normals_trans, nn_hand_params=nn_hand_params)

      optimized_sv_infos = {}
      optimized_sv_infos.update(optimized_dict)
      optimized_sv_infos.update(
        {
          'tot_base_pts_trans': tot_base_pts_trans,
          'tot_base_normals_trans': tot_base_normals_trans
        }
      )

      optimized_sv_infos_sv_fn_nm = f"optimized_infos_sv_dict_seq_{i_test_seq}_seed_{seed}_tag_{test_tag}_{cur_seq_tag}_wmaskanchors_multi_ntag_{len(test_tags)}.npy"
      optimized_sv_infos_sv_fn = os.path.join(new_pred_infos_sv_folder, optimized_sv_infos_sv_fn_nm)
      np.save(optimized_sv_infos_sv_fn, optimized_sv_infos)
      print(f"optimized infos saved to {optimized_sv_infos_sv_fn}")
      

